{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit ('attended-pooling-v2': conda)",
   "metadata": {
    "interpreter": {
     "hash": "4be85f8989b472e81eec2a9f752c4651e901496527f53c28cc3d62caf3c00114"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "import random\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn \n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader\n",
    "from sklearn import metrics\n",
    "\n",
    "import layers\n",
    "import models\n",
    "import custom_loss\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_preprocessing import DrugDataset, DrugDataLoader, TOTAL_ATOM_FEATS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "55"
      ]
     },
     "metadata": {},
     "execution_count": 3
    }
   ],
   "source": [
    "TOTAL_ATOM_FEATS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the training, validation and test CSV files, in that order.\n",
    "df_ddi_train, df_ddi_val, df_ddi_test = map(\n",
    "    pd.read_csv,\n",
    "    ('data/ddi_training.csv', 'data/ddi_validation.csv', 'data/ddi_test.csv'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (d1, d2, type) tuple per row of each split.\n",
    "train_tup = list(zip(df_ddi_train['d1'], df_ddi_train['d2'], df_ddi_train['type']))\n",
    "val_tup = list(zip(df_ddi_val['d1'], df_ddi_val['d2'], df_ddi_val['type']))\n",
    "test_tup = list(zip(df_ddi_test['d1'], df_ddi_test['d2'], df_ddi_test['type']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "115185"
      ]
     },
     "metadata": {},
     "execution_count": 6
    }
   ],
   "source": [
    "len(train_tup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "38348"
      ]
     },
     "metadata": {},
     "execution_count": 7
    }
   ],
   "source": [
    "len(val_tup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "38337"
      ]
     },
     "metadata": {},
     "execution_count": 8
    }
   ],
   "source": [
    "len(test_tup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(0.6003283473184969, 0.19980716109866056, 0.19986449158284256)"
      ]
     },
     "metadata": {},
     "execution_count": 9
    }
   ],
   "source": [
    "# Share of all triples held by each split, shown in (train, test, val) order.\n",
    "total = sum(len(split) for split in (val_tup, train_tup, test_tup))\n",
    "tuple(len(split) / total for split in (train_tup, test_tup, val_tup))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "n_atom_feats = TOTAL_ATOM_FEATS  # imported from data_preprocessing\n",
    "n_atom_hid = 64\n",
    "rel_total = 86\n",
    "lr = 1e-2\n",
    "weight_decay = 5e-4\n",
    "n_epochs = 300\n",
    "neg_samples = 1  # passed to DrugDataset as neg_ent\n",
    "batch_size = 1024\n",
    "data_size_ratio = 1  # passed to DrugDataset as ratio\n",
    "kge_dim = 64\n",
    "\n",
    "# Seed the RNGs of every library imported above so that a fresh\n",
    "# Restart & Run All starts from the same random state each time.\n",
    "# NOTE(review): GPU kernels may still be non-deterministic, and any private RNG\n",
    "# inside the project modules is not covered - confirm if exact repeatability matters.\n",
    "seed = 0\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)\n",
    "test_data = DrugDataset(test_tup, disjoint_split=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Training with 115185 samples, validating with 38348, and testing with 38337\n"
     ]
    }
   ],
   "source": [
    "print(f\"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training loader shuffles; the two evaluation loaders use 3x larger batches.\n",
    "train_data_loader = DrugDataLoader(train_data, shuffle=True, batch_size=batch_size)\n",
    "val_data_loader, test_data_loader = (\n",
    "    DrugDataLoader(split, batch_size=batch_size * 3) for split in (val_data, test_data)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_compute(batch, device, training=True):\n",
    "    '''\n",
    "    Score one batch of positive and negative triples with the notebook-level `model`.\n",
    "\n",
    "    * batch: (pos_tri, neg_tri)\n",
    "    * pos/neg_tri: (batch_h, batch_t, batch_r); every element is moved to `device`\n",
    "    * training: accepted but not used by this function\n",
    "\n",
    "    Returns (p_score, n_score, probas_pred, ground_truth): the two raw model\n",
    "    outputs, then the sigmoid of both outputs (detached, positives first) and\n",
    "    the matching 1/0 labels, each concatenated into a single array.\n",
    "    '''\n",
    "    positives, negatives = batch\n",
    "\n",
    "    positives = [item.to(device=device) for item in positives]\n",
    "    p_score = model(positives)\n",
    "\n",
    "    negatives = [item.to(device=device) for item in negatives]\n",
    "    n_score = model(negatives)\n",
    "\n",
    "    # Only detached copies go into the probabilities; the raw scores are returned untouched.\n",
    "    probas_pred = np.concatenate([\n",
    "        torch.sigmoid(p_score.detach()).cpu(),\n",
    "        torch.sigmoid(n_score.detach()).cpu(),\n",
    "    ])\n",
    "    ground_truth = np.concatenate([np.ones(len(p_score)), np.zeros(len(n_score))])\n",
    "\n",
    "    return p_score, n_score, probas_pred, ground_truth\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_compute_metrics(probas_pred, target):\n",
    "    '''\n",
    "    Compute accuracy, ROC-AUC and precision-recall AUC for binary predictions.\n",
    "\n",
    "    * probas_pred: numpy array of predicted probabilities; values >= 0.5 count\n",
    "      as class 1 for the accuracy\n",
    "    * target: ground-truth labels aligned with probas_pred\n",
    "\n",
    "    Returns (acc, auc_roc, auc_prc).\n",
    "    '''\n",
    "    # The builtin int replaces np.int, an alias that NumPy deprecated in 1.20\n",
    "    # and removed in 1.24 (it now raises AttributeError).\n",
    "    pred = (probas_pred >= 0.5).astype(int)\n",
    "\n",
    "    acc = metrics.accuracy_score(target, pred)\n",
    "    auc_roc = metrics.roc_auc_score(target, probas_pred)\n",
    "\n",
    "    # Area under the precision-recall curve, with recall on the x axis.\n",
    "    precision, recall, _ = metrics.precision_recall_curve(target, probas_pred)\n",
    "    auc_prc = metrics.auc(recall, precision)\n",
    "\n",
    "    return acc, auc_roc, auc_prc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_data_loader, val_data_loader, loss_fn,  optimizer, n_epochs, device, scheduler=None):\n",
    "    '''\n",
    "    Run `n_epochs` epochs of optimisation, printing loss and metrics after each one.\n",
    "\n",
    "    * model: put in train mode for the training pass and eval mode for validation.\n",
    "      NOTE: batches are scored through do_compute, which calls the notebook-level\n",
    "      `model`, so the object passed here has to be that same model.\n",
    "    * train_data_loader / val_data_loader: iterables of batches accepted by do_compute\n",
    "    * loss_fn: called as loss_fn(p_score, n_score); must return (loss, loss_p, loss_n)\n",
    "    * optimizer: stepped once per training batch\n",
    "    * n_epochs: number of epochs; they are reported starting from 1\n",
    "    * device: forwarded to do_compute\n",
    "    * scheduler: optional; stepped once at the end of every epoch\n",
    "    '''\n",
    "    print('Starting training at', datetime.today())\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        start = time.time()\n",
    "        train_loss, train_seen = 0, 0\n",
    "        val_loss, val_seen = 0, 0\n",
    "        train_probas_pred, train_ground_truth = [], []\n",
    "        val_probas_pred, val_ground_truth = [], []\n",
    "\n",
    "        model.train()\n",
    "        for batch in train_data_loader:\n",
    "            p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)\n",
    "            train_probas_pred.append(probas_pred)\n",
    "            train_ground_truth.append(ground_truth)\n",
    "            loss, _, _ = loss_fn(p_score, n_score)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Weight every batch loss by its size so the epoch value is a per-sample mean.\n",
    "            train_loss += loss.item() * len(p_score)\n",
    "            train_seen += len(p_score)\n",
    "        # Divide by the samples actually processed instead of the length of a\n",
    "        # notebook-level dataset, so the mean is right for whichever loader was passed in.\n",
    "        train_loss /= max(train_seen, 1)\n",
    "\n",
    "        train_probas_pred = np.concatenate(train_probas_pred)\n",
    "        train_ground_truth = np.concatenate(train_ground_truth)\n",
    "        train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(train_probas_pred, train_ground_truth)\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for batch in val_data_loader:\n",
    "                p_score, n_score, probas_pred, ground_truth = do_compute(batch, device)\n",
    "                val_probas_pred.append(probas_pred)\n",
    "                val_ground_truth.append(ground_truth)\n",
    "                loss, _, _ = loss_fn(p_score, n_score)\n",
    "                val_loss += loss.item() * len(p_score)\n",
    "                val_seen += len(p_score)\n",
    "\n",
    "        val_loss /= max(val_seen, 1)\n",
    "        val_probas_pred = np.concatenate(val_probas_pred)\n",
    "        val_ground_truth = np.concatenate(val_ground_truth)\n",
    "        val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(val_probas_pred, val_ground_truth)\n",
    "\n",
    "        if scheduler is not None:\n",
    "            print('scheduling')\n",
    "            scheduler.step()\n",
    "\n",
    "        print(f'Epoch: {epoch} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},'\n",
    "        f' train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}')\n",
    "        print(f'\\t\\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "SSI_DDI(\n",
       "  (initial_norm): LayerNorm(55)\n",
       "  (net_norms): ModuleList(\n",
       "    (0): LayerNorm(64)\n",
       "    (1): LayerNorm(64)\n",
       "    (2): LayerNorm(64)\n",
       "    (3): LayerNorm(64)\n",
       "  )\n",
       "  (block0): SSI_DDI_Block(\n",
       "    (conv): GATConv(55, 32, heads=2)\n",
       "    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)\n",
       "  )\n",
       "  (block1): SSI_DDI_Block(\n",
       "    (conv): GATConv(64, 32, heads=2)\n",
       "    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)\n",
       "  )\n",
       "  (block2): SSI_DDI_Block(\n",
       "    (conv): GATConv(64, 32, heads=2)\n",
       "    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)\n",
       "  )\n",
       "  (block3): SSI_DDI_Block(\n",
       "    (conv): GATConv(64, 32, heads=2)\n",
       "    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1)\n",
       "  )\n",
       "  (co_attention): CoAttentionLayer()\n",
       "  (KGE): RESCAL(86, torch.Size([86, 4096]))\n",
       ")"
      ]
     },
     "metadata": {},
     "execution_count": 18
    }
   ],
   "source": [
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "model = models.SSI_DDI(\n",
    "    n_atom_feats,\n",
    "    n_atom_hid,\n",
    "    kge_dim,\n",
    "    rel_total,\n",
    "    heads_out_feat_params=[32, 32, 32, 32],\n",
    "    blocks_params=[2, 2, 2, 2],\n",
    ")\n",
    "loss = custom_loss.SigmoidLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "# The learning rate becomes lr * 0.96 ** epoch, where epoch counts scheduler steps.\n",
    "scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.96 ** epoch)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.to(device=device);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "tags": [
     "outputPrepend"
    ]
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "ss: 0.4736, val_loss: 0.4750, train_acc: 0.7712, val_acc:0.7714\n",
      "\t\ttrain_roc: 0.8506, val_roc: 0.8497, train_auprc: 0.8303, val_auprc: 0.8305\n",
      "scheduling\n",
      "Epoch: 12 (56.1362s), train_loss: 0.4580, val_loss: 0.4600, train_acc: 0.7818, val_acc:0.7807\n",
      "\t\ttrain_roc: 0.8612, val_roc: 0.8615, train_auprc: 0.8416, val_auprc: 0.8438\n",
      "scheduling\n",
      "Epoch: 13 (54.5477s), train_loss: 0.4470, val_loss: 0.4521, train_acc: 0.7888, val_acc:0.7863\n",
      "\t\ttrain_roc: 0.8685, val_roc: 0.8662, train_auprc: 0.8493, val_auprc: 0.8471\n",
      "scheduling\n",
      "Epoch: 14 (53.8652s), train_loss: 0.4397, val_loss: 0.4519, train_acc: 0.7948, val_acc:0.7853\n",
      "\t\ttrain_roc: 0.8734, val_roc: 0.8663, train_auprc: 0.8537, val_auprc: 0.8445\n",
      "scheduling\n",
      "Epoch: 15 (59.7325s), train_loss: 0.4283, val_loss: 0.4822, train_acc: 0.8018, val_acc:0.7563\n",
      "\t\ttrain_roc: 0.8806, val_roc: 0.8657, train_auprc: 0.8619, val_auprc: 0.8499\n",
      "scheduling\n",
      "Epoch: 16 (56.8572s), train_loss: 0.4327, val_loss: 0.4286, train_acc: 0.7996, val_acc:0.8034\n",
      "\t\ttrain_roc: 0.8782, val_roc: 0.8808, train_auprc: 0.8593, val_auprc: 0.8599\n",
      "scheduling\n",
      "Epoch: 17 (57.6143s), train_loss: 0.4113, val_loss: 0.4116, train_acc: 0.8123, val_acc:0.8138\n",
      "\t\ttrain_roc: 0.8902, val_roc: 0.8912, train_auprc: 0.8718, val_auprc: 0.8756\n",
      "scheduling\n",
      "Epoch: 18 (60.8719s), train_loss: 0.4016, val_loss: 0.4102, train_acc: 0.8198, val_acc:0.8144\n",
      "\t\ttrain_roc: 0.8963, val_roc: 0.8926, train_auprc: 0.8780, val_auprc: 0.8758\n",
      "scheduling\n",
      "Epoch: 19 (70.4472s), train_loss: 0.3939, val_loss: 0.3997, train_acc: 0.8235, val_acc:0.8211\n",
      "\t\ttrain_roc: 0.9000, val_roc: 0.8972, train_auprc: 0.8831, val_auprc: 0.8795\n",
      "scheduling\n",
      "Epoch: 20 (80.8767s), train_loss: 0.3870, val_loss: 0.3961, train_acc: 0.8278, val_acc:0.8192\n",
      "\t\ttrain_roc: 0.9036, val_roc: 0.8998, train_auprc: 0.8863, val_auprc: 0.8849\n",
      "scheduling\n",
      "Epoch: 21 (77.5738s), train_loss: 0.3787, val_loss: 0.3889, train_acc: 0.8331, val_acc:0.8236\n",
      "\t\ttrain_roc: 0.9084, val_roc: 0.9046, train_auprc: 0.8917, val_auprc: 0.8880\n",
      "scheduling\n",
      "Epoch: 22 (78.1308s), train_loss: 0.3716, val_loss: 0.3776, train_acc: 0.8366, val_acc:0.8331\n",
      "\t\ttrain_roc: 0.9115, val_roc: 0.9087, train_auprc: 0.8952, val_auprc: 0.8931\n",
      "scheduling\n",
      "Epoch: 23 (77.4256s), train_loss: 0.3640, val_loss: 0.3751, train_acc: 0.8423, val_acc:0.8297\n",
      "\t\ttrain_roc: 0.9155, val_roc: 0.9128, train_auprc: 0.8996, val_auprc: 0.8986\n",
      "scheduling\n",
      "Epoch: 24 (77.4585s), train_loss: 0.3535, val_loss: 0.3696, train_acc: 0.8480, val_acc:0.8417\n",
      "\t\ttrain_roc: 0.9206, val_roc: 0.9144, train_auprc: 0.9047, val_auprc: 0.8976\n",
      "scheduling\n",
      "Epoch: 25 (77.1571s), train_loss: 0.3507, val_loss: 0.3606, train_acc: 0.8503, val_acc:0.8430\n",
      "\t\ttrain_roc: 0.9214, val_roc: 0.9179, train_auprc: 0.9056, val_auprc: 0.9026\n",
      "scheduling\n",
      "Epoch: 26 (80.2267s), train_loss: 0.3414, val_loss: 0.3621, train_acc: 0.8552, val_acc:0.8383\n",
      "\t\ttrain_roc: 0.9257, val_roc: 0.9209, train_auprc: 0.9106, val_auprc: 0.9070\n",
      "scheduling\n",
      "Epoch: 27 (84.3864s), train_loss: 0.3335, val_loss: 0.3472, train_acc: 0.8585, val_acc:0.8491\n",
      "\t\ttrain_roc: 0.9292, val_roc: 0.9238, train_auprc: 0.9148, val_auprc: 0.9098\n",
      "scheduling\n",
      "Epoch: 28 (82.1514s), train_loss: 0.3265, val_loss: 0.3420, train_acc: 0.8629, val_acc:0.8553\n",
      "\t\ttrain_roc: 0.9321, val_roc: 0.9257, train_auprc: 0.9183, val_auprc: 0.9111\n",
      "scheduling\n",
      "Epoch: 29 (77.2925s), train_loss: 0.3184, val_loss: 0.3377, train_acc: 0.8675, val_acc:0.8566\n",
      "\t\ttrain_roc: 0.9355, val_roc: 0.9273, train_auprc: 0.9213, val_auprc: 0.9137\n",
      "scheduling\n",
      "Epoch: 30 (77.8426s), train_loss: 0.3171, val_loss: 0.3246, train_acc: 0.8677, val_acc:0.8620\n",
      "\t\ttrain_roc: 0.9357, val_roc: 0.9333, train_auprc: 0.9221, val_auprc: 0.9219\n",
      "scheduling\n",
      "Epoch: 31 (77.1690s), train_loss: 0.3072, val_loss: 0.3275, train_acc: 0.8734, val_acc:0.8618\n",
      "\t\ttrain_roc: 0.9397, val_roc: 0.9316, train_auprc: 0.9268, val_auprc: 0.9192\n",
      "scheduling\n",
      "Epoch: 32 (76.6501s), train_loss: 0.3060, val_loss: 0.3208, train_acc: 0.8737, val_acc:0.8666\n",
      "\t\ttrain_roc: 0.9401, val_roc: 0.9349, train_auprc: 0.9270, val_auprc: 0.9230\n",
      "scheduling\n",
      "Epoch: 33 (75.5525s), train_loss: 0.3003, val_loss: 0.3146, train_acc: 0.8774, val_acc:0.8699\n",
      "\t\ttrain_roc: 0.9424, val_roc: 0.9387, train_auprc: 0.9294, val_auprc: 0.9271\n",
      "scheduling\n",
      "Epoch: 34 (75.2574s), train_loss: 0.2961, val_loss: 0.3142, train_acc: 0.8797, val_acc:0.8714\n",
      "\t\ttrain_roc: 0.9439, val_roc: 0.9379, train_auprc: 0.9310, val_auprc: 0.9256\n",
      "scheduling\n",
      "Epoch: 35 (75.2868s), train_loss: 0.2898, val_loss: 0.3060, train_acc: 0.8832, val_acc:0.8742\n",
      "\t\ttrain_roc: 0.9462, val_roc: 0.9405, train_auprc: 0.9336, val_auprc: 0.9287\n",
      "scheduling\n",
      "Epoch: 36 (76.8871s), train_loss: 0.2877, val_loss: 0.3091, train_acc: 0.8838, val_acc:0.8713\n",
      "\t\ttrain_roc: 0.9470, val_roc: 0.9396, train_auprc: 0.9348, val_auprc: 0.9279\n",
      "scheduling\n",
      "Epoch: 37 (72.7200s), train_loss: 0.2812, val_loss: 0.2953, train_acc: 0.8872, val_acc:0.8816\n",
      "\t\ttrain_roc: 0.9492, val_roc: 0.9446, train_auprc: 0.9373, val_auprc: 0.9334\n",
      "scheduling\n",
      "Epoch: 38 (72.5640s), train_loss: 0.2725, val_loss: 0.2953, train_acc: 0.8920, val_acc:0.8800\n",
      "\t\ttrain_roc: 0.9522, val_roc: 0.9452, train_auprc: 0.9408, val_auprc: 0.9345\n",
      "scheduling\n",
      "Epoch: 39 (73.0780s), train_loss: 0.2713, val_loss: 0.2883, train_acc: 0.8925, val_acc:0.8828\n",
      "\t\ttrain_roc: 0.9524, val_roc: 0.9474, train_auprc: 0.9411, val_auprc: 0.9375\n",
      "scheduling\n",
      "Epoch: 40 (73.2240s), train_loss: 0.2698, val_loss: 0.2951, train_acc: 0.8930, val_acc:0.8801\n",
      "\t\ttrain_roc: 0.9530, val_roc: 0.9456, train_auprc: 0.9416, val_auprc: 0.9354\n",
      "scheduling\n",
      "Epoch: 41 (72.7760s), train_loss: 0.2692, val_loss: 0.2821, train_acc: 0.8930, val_acc:0.8869\n",
      "\t\ttrain_roc: 0.9532, val_roc: 0.9489, train_auprc: 0.9421, val_auprc: 0.9382\n",
      "scheduling\n",
      "Epoch: 42 (72.7810s), train_loss: 0.2592, val_loss: 0.2864, train_acc: 0.8991, val_acc:0.8836\n",
      "\t\ttrain_roc: 0.9566, val_roc: 0.9493, train_auprc: 0.9458, val_auprc: 0.9378\n",
      "scheduling\n",
      "Epoch: 43 (72.7500s), train_loss: 0.2560, val_loss: 0.2781, train_acc: 0.8996, val_acc:0.8893\n",
      "\t\ttrain_roc: 0.9575, val_roc: 0.9508, train_auprc: 0.9472, val_auprc: 0.9406\n",
      "scheduling\n",
      "Epoch: 44 (72.3830s), train_loss: 0.2522, val_loss: 0.2837, train_acc: 0.9015, val_acc:0.8855\n",
      "\t\ttrain_roc: 0.9585, val_roc: 0.9502, train_auprc: 0.9485, val_auprc: 0.9398\n",
      "scheduling\n",
      "Epoch: 45 (72.6080s), train_loss: 0.2508, val_loss: 0.2709, train_acc: 0.9023, val_acc:0.8929\n",
      "\t\ttrain_roc: 0.9589, val_roc: 0.9535, train_auprc: 0.9488, val_auprc: 0.9447\n",
      "scheduling\n",
      "Epoch: 46 (72.4090s), train_loss: 0.2471, val_loss: 0.2686, train_acc: 0.9038, val_acc:0.8942\n",
      "\t\ttrain_roc: 0.9602, val_roc: 0.9542, train_auprc: 0.9503, val_auprc: 0.9444\n",
      "scheduling\n",
      "Epoch: 47 (72.9270s), train_loss: 0.2446, val_loss: 0.2758, train_acc: 0.9054, val_acc:0.8898\n",
      "\t\ttrain_roc: 0.9611, val_roc: 0.9520, train_auprc: 0.9512, val_auprc: 0.9423\n",
      "scheduling\n",
      "Epoch: 48 (72.3560s), train_loss: 0.2426, val_loss: 0.2710, train_acc: 0.9075, val_acc:0.8918\n",
      "\t\ttrain_roc: 0.9615, val_roc: 0.9546, train_auprc: 0.9513, val_auprc: 0.9457\n",
      "scheduling\n",
      "Epoch: 49 (72.5260s), train_loss: 0.2370, val_loss: 0.2625, train_acc: 0.9098, val_acc:0.8982\n",
      "\t\ttrain_roc: 0.9632, val_roc: 0.9565, train_auprc: 0.9536, val_auprc: 0.9473\n",
      "scheduling\n",
      "Epoch: 50 (72.4780s), train_loss: 0.2331, val_loss: 0.2624, train_acc: 0.9117, val_acc:0.8971\n",
      "\t\ttrain_roc: 0.9643, val_roc: 0.9562, train_auprc: 0.9549, val_auprc: 0.9473\n",
      "scheduling\n",
      "Epoch: 51 (72.5200s), train_loss: 0.2346, val_loss: 0.2574, train_acc: 0.9113, val_acc:0.8990\n",
      "\t\ttrain_roc: 0.9637, val_roc: 0.9575, train_auprc: 0.9538, val_auprc: 0.9491\n",
      "scheduling\n",
      "Epoch: 52 (72.3160s), train_loss: 0.2306, val_loss: 0.2572, train_acc: 0.9128, val_acc:0.8997\n",
      "\t\ttrain_roc: 0.9648, val_roc: 0.9592, train_auprc: 0.9554, val_auprc: 0.9507\n",
      "scheduling\n",
      "Epoch: 53 (72.4380s), train_loss: 0.2263, val_loss: 0.2516, train_acc: 0.9143, val_acc:0.9018\n",
      "\t\ttrain_roc: 0.9662, val_roc: 0.9593, train_auprc: 0.9572, val_auprc: 0.9513\n",
      "scheduling\n",
      "Epoch: 54 (72.4800s), train_loss: 0.2231, val_loss: 0.2527, train_acc: 0.9158, val_acc:0.9030\n",
      "\t\ttrain_roc: 0.9671, val_roc: 0.9593, train_auprc: 0.9584, val_auprc: 0.9506\n",
      "scheduling\n",
      "Epoch: 55 (72.8400s), train_loss: 0.2251, val_loss: 0.2508, train_acc: 0.9150, val_acc:0.9026\n",
      "\t\ttrain_roc: 0.9664, val_roc: 0.9606, train_auprc: 0.9571, val_auprc: 0.9532\n",
      "scheduling\n",
      "Epoch: 56 (72.4990s), train_loss: 0.2185, val_loss: 0.2465, train_acc: 0.9183, val_acc:0.9050\n",
      "\t\ttrain_roc: 0.9682, val_roc: 0.9615, train_auprc: 0.9600, val_auprc: 0.9539\n",
      "scheduling\n",
      "Epoch: 57 (72.2590s), train_loss: 0.2185, val_loss: 0.2511, train_acc: 0.9187, val_acc:0.9032\n",
      "\t\ttrain_roc: 0.9683, val_roc: 0.9599, train_auprc: 0.9599, val_auprc: 0.9512\n",
      "scheduling\n",
      "Epoch: 58 (72.2550s), train_loss: 0.2148, val_loss: 0.2492, train_acc: 0.9205, val_acc:0.9028\n",
      "\t\ttrain_roc: 0.9693, val_roc: 0.9616, train_auprc: 0.9611, val_auprc: 0.9538\n",
      "scheduling\n",
      "Epoch: 59 (72.4230s), train_loss: 0.2135, val_loss: 0.2465, train_acc: 0.9213, val_acc:0.9054\n",
      "\t\ttrain_roc: 0.9695, val_roc: 0.9612, train_auprc: 0.9611, val_auprc: 0.9528\n",
      "scheduling\n",
      "Epoch: 60 (72.6720s), train_loss: 0.2117, val_loss: 0.2429, train_acc: 0.9216, val_acc:0.9087\n",
      "\t\ttrain_roc: 0.9700, val_roc: 0.9622, train_auprc: 0.9619, val_auprc: 0.9537\n",
      "scheduling\n",
      "Epoch: 61 (72.4400s), train_loss: 0.2118, val_loss: 0.2418, train_acc: 0.9208, val_acc:0.9080\n",
      "\t\ttrain_roc: 0.9701, val_roc: 0.9629, train_auprc: 0.9621, val_auprc: 0.9552\n",
      "scheduling\n",
      "Epoch: 62 (72.5130s), train_loss: 0.2105, val_loss: 0.2393, train_acc: 0.9219, val_acc:0.9086\n",
      "\t\ttrain_roc: 0.9704, val_roc: 0.9639, train_auprc: 0.9623, val_auprc: 0.9568\n",
      "scheduling\n",
      "Epoch: 63 (72.4500s), train_loss: 0.2076, val_loss: 0.2405, train_acc: 0.9235, val_acc:0.9073\n",
      "\t\ttrain_roc: 0.9711, val_roc: 0.9646, train_auprc: 0.9635, val_auprc: 0.9577\n",
      "scheduling\n",
      "Epoch: 64 (72.6990s), train_loss: 0.2060, val_loss: 0.2418, train_acc: 0.9241, val_acc:0.9073\n",
      "\t\ttrain_roc: 0.9717, val_roc: 0.9629, train_auprc: 0.9641, val_auprc: 0.9560\n",
      "scheduling\n",
      "Epoch: 65 (60.5422s), train_loss: 0.2046, val_loss: 0.2395, train_acc: 0.9252, val_acc:0.9104\n",
      "\t\ttrain_roc: 0.9719, val_roc: 0.9642, train_auprc: 0.9642, val_auprc: 0.9570\n",
      "scheduling\n",
      "Epoch: 66 (53.1812s), train_loss: 0.2025, val_loss: 0.2368, train_acc: 0.9256, val_acc:0.9106\n",
      "\t\ttrain_roc: 0.9724, val_roc: 0.9644, train_auprc: 0.9651, val_auprc: 0.9573\n",
      "scheduling\n",
      "Epoch: 67 (53.4170s), train_loss: 0.2014, val_loss: 0.2333, train_acc: 0.9263, val_acc:0.9117\n",
      "\t\ttrain_roc: 0.9727, val_roc: 0.9650, train_auprc: 0.9649, val_auprc: 0.9589\n",
      "scheduling\n",
      "Epoch: 68 (53.1052s), train_loss: 0.2001, val_loss: 0.2348, train_acc: 0.9264, val_acc:0.9108\n",
      "\t\ttrain_roc: 0.9730, val_roc: 0.9651, train_auprc: 0.9657, val_auprc: 0.9584\n",
      "scheduling\n",
      "Epoch: 69 (53.6590s), train_loss: 0.1987, val_loss: 0.2355, train_acc: 0.9275, val_acc:0.9107\n",
      "\t\ttrain_roc: 0.9732, val_roc: 0.9656, train_auprc: 0.9657, val_auprc: 0.9593\n",
      "scheduling\n",
      "Epoch: 70 (52.6317s), train_loss: 0.1980, val_loss: 0.2321, train_acc: 0.9284, val_acc:0.9120\n",
      "\t\ttrain_roc: 0.9735, val_roc: 0.9662, train_auprc: 0.9657, val_auprc: 0.9599\n",
      "scheduling\n",
      "Epoch: 71 (53.1830s), train_loss: 0.1951, val_loss: 0.2343, train_acc: 0.9294, val_acc:0.9106\n",
      "\t\ttrain_roc: 0.9743, val_roc: 0.9657, train_auprc: 0.9672, val_auprc: 0.9596\n",
      "scheduling\n",
      "Epoch: 72 (52.1800s), train_loss: 0.1948, val_loss: 0.2278, train_acc: 0.9294, val_acc:0.9145\n",
      "\t\ttrain_roc: 0.9743, val_roc: 0.9669, train_auprc: 0.9672, val_auprc: 0.9607\n",
      "scheduling\n",
      "Epoch: 73 (60.1409s), train_loss: 0.1938, val_loss: 0.2349, train_acc: 0.9296, val_acc:0.9105\n",
      "\t\ttrain_roc: 0.9743, val_roc: 0.9659, train_auprc: 0.9671, val_auprc: 0.9592\n",
      "scheduling\n",
      "Epoch: 74 (58.1091s), train_loss: 0.1913, val_loss: 0.2312, train_acc: 0.9307, val_acc:0.9130\n",
      "\t\ttrain_roc: 0.9752, val_roc: 0.9664, train_auprc: 0.9682, val_auprc: 0.9600\n",
      "scheduling\n",
      "Epoch: 75 (54.7017s), train_loss: 0.1915, val_loss: 0.2318, train_acc: 0.9308, val_acc:0.9132\n",
      "\t\ttrain_roc: 0.9749, val_roc: 0.9655, train_auprc: 0.9679, val_auprc: 0.9585\n",
      "scheduling\n",
      "Epoch: 76 (56.2158s), train_loss: 0.1900, val_loss: 0.2298, train_acc: 0.9311, val_acc:0.9137\n",
      "\t\ttrain_roc: 0.9752, val_roc: 0.9664, train_auprc: 0.9685, val_auprc: 0.9600\n",
      "scheduling\n",
      "Epoch: 77 (74.1254s), train_loss: 0.1896, val_loss: 0.2295, train_acc: 0.9311, val_acc:0.9152\n",
      "\t\ttrain_roc: 0.9755, val_roc: 0.9666, train_auprc: 0.9687, val_auprc: 0.9589\n",
      "scheduling\n",
      "Epoch: 78 (75.6095s), train_loss: 0.1875, val_loss: 0.2247, train_acc: 0.9323, val_acc:0.9162\n",
      "\t\ttrain_roc: 0.9762, val_roc: 0.9679, train_auprc: 0.9696, val_auprc: 0.9615\n",
      "scheduling\n",
      "Epoch: 79 (74.8380s), train_loss: 0.1876, val_loss: 0.2326, train_acc: 0.9327, val_acc:0.9129\n",
      "\t\ttrain_roc: 0.9757, val_roc: 0.9677, train_auprc: 0.9688, val_auprc: 0.9613\n",
      "scheduling\n",
      "Epoch: 80 (73.6014s), train_loss: 0.1863, val_loss: 0.2219, train_acc: 0.9325, val_acc:0.9178\n",
      "\t\ttrain_roc: 0.9762, val_roc: 0.9686, train_auprc: 0.9697, val_auprc: 0.9626\n",
      "scheduling\n",
      "Epoch: 81 (74.3363s), train_loss: 0.1874, val_loss: 0.2226, train_acc: 0.9328, val_acc:0.9167\n",
      "\t\ttrain_roc: 0.9758, val_roc: 0.9686, train_auprc: 0.9691, val_auprc: 0.9626\n",
      "scheduling\n",
      "Epoch: 82 (76.9084s), train_loss: 0.1858, val_loss: 0.2243, train_acc: 0.9332, val_acc:0.9157\n",
      "\t\ttrain_roc: 0.9764, val_roc: 0.9684, train_auprc: 0.9698, val_auprc: 0.9628\n",
      "scheduling\n",
      "Epoch: 83 (69.8200s), train_loss: 0.1834, val_loss: 0.2213, train_acc: 0.9342, val_acc:0.9184\n",
      "\t\ttrain_roc: 0.9769, val_roc: 0.9691, train_auprc: 0.9705, val_auprc: 0.9634\n",
      "scheduling\n",
      "Epoch: 84 (56.6463s), train_loss: 0.1829, val_loss: 0.2188, train_acc: 0.9347, val_acc:0.9188\n",
      "\t\ttrain_roc: 0.9769, val_roc: 0.9697, train_auprc: 0.9704, val_auprc: 0.9642\n",
      "scheduling\n",
      "Epoch: 85 (53.7245s), train_loss: 0.1827, val_loss: 0.2219, train_acc: 0.9346, val_acc:0.9175\n",
      "\t\ttrain_roc: 0.9770, val_roc: 0.9687, train_auprc: 0.9704, val_auprc: 0.9628\n",
      "scheduling\n",
      "Epoch: 86 (53.7540s), train_loss: 0.1815, val_loss: 0.2221, train_acc: 0.9351, val_acc:0.9173\n",
      "\t\ttrain_roc: 0.9773, val_roc: 0.9685, train_auprc: 0.9709, val_auprc: 0.9622\n",
      "scheduling\n",
      "Epoch: 87 (58.1392s), train_loss: 0.1814, val_loss: 0.2225, train_acc: 0.9350, val_acc:0.9175\n",
      "\t\ttrain_roc: 0.9772, val_roc: 0.9688, train_auprc: 0.9708, val_auprc: 0.9627\n",
      "scheduling\n",
      "Epoch: 88 (57.2961s), train_loss: 0.1813, val_loss: 0.2242, train_acc: 0.9357, val_acc:0.9169\n",
      "\t\ttrain_roc: 0.9771, val_roc: 0.9683, train_auprc: 0.9706, val_auprc: 0.9620\n",
      "scheduling\n",
      "Epoch: 89 (53.9114s), train_loss: 0.1793, val_loss: 0.2200, train_acc: 0.9362, val_acc:0.9185\n",
      "\t\ttrain_roc: 0.9780, val_roc: 0.9693, train_auprc: 0.9717, val_auprc: 0.9639\n",
      "scheduling\n",
      "Epoch: 90 (53.8550s), train_loss: 0.1787, val_loss: 0.2182, train_acc: 0.9364, val_acc:0.9203\n",
      "\t\ttrain_roc: 0.9780, val_roc: 0.9696, train_auprc: 0.9719, val_auprc: 0.9637\n",
      "scheduling\n",
      "Epoch: 91 (67.1382s), train_loss: 0.1789, val_loss: 0.2215, train_acc: 0.9363, val_acc:0.9179\n",
      "\t\ttrain_roc: 0.9778, val_roc: 0.9689, train_auprc: 0.9715, val_auprc: 0.9632\n",
      "scheduling\n",
      "Epoch: 92 (67.5178s), train_loss: 0.1777, val_loss: 0.2198, train_acc: 0.9372, val_acc:0.9197\n",
      "\t\ttrain_roc: 0.9781, val_roc: 0.9696, train_auprc: 0.9718, val_auprc: 0.9635\n",
      "scheduling\n",
      "Epoch: 93 (56.5784s), train_loss: 0.1769, val_loss: 0.2217, train_acc: 0.9377, val_acc:0.9174\n",
      "\t\ttrain_roc: 0.9781, val_roc: 0.9688, train_auprc: 0.9716, val_auprc: 0.9626\n",
      "scheduling\n",
      "Epoch: 94 (53.1232s), train_loss: 0.1775, val_loss: 0.2218, train_acc: 0.9372, val_acc:0.9184\n",
      "\t\ttrain_roc: 0.9780, val_roc: 0.9686, train_auprc: 0.9717, val_auprc: 0.9623\n",
      "scheduling\n",
      "Epoch: 95 (54.0346s), train_loss: 0.1771, val_loss: 0.2222, train_acc: 0.9370, val_acc:0.9188\n",
      "\t\ttrain_roc: 0.9783, val_roc: 0.9693, train_auprc: 0.9721, val_auprc: 0.9630\n",
      "scheduling\n",
      "Epoch: 96 (52.9832s), train_loss: 0.1763, val_loss: 0.2211, train_acc: 0.9375, val_acc:0.9181\n",
      "\t\ttrain_roc: 0.9786, val_roc: 0.9696, train_auprc: 0.9728, val_auprc: 0.9636\n",
      "scheduling\n",
      "Epoch: 97 (53.4439s), train_loss: 0.1755, val_loss: 0.2200, train_acc: 0.9382, val_acc:0.9189\n",
      "\t\ttrain_roc: 0.9786, val_roc: 0.9695, train_auprc: 0.9726, val_auprc: 0.9632\n",
      "scheduling\n",
      "Epoch: 98 (62.3852s), train_loss: 0.1741, val_loss: 0.2185, train_acc: 0.9381, val_acc:0.9191\n",
      "\t\ttrain_roc: 0.9790, val_roc: 0.9699, train_auprc: 0.9731, val_auprc: 0.9644\n",
      "scheduling\n",
      "Epoch: 99 (53.5809s), train_loss: 0.1756, val_loss: 0.2160, train_acc: 0.9384, val_acc:0.9209\n",
      "\t\ttrain_roc: 0.9786, val_roc: 0.9704, train_auprc: 0.9722, val_auprc: 0.9649\n",
      "scheduling\n",
      "Epoch: 100 (52.4450s), train_loss: 0.1746, val_loss: 0.2154, train_acc: 0.9382, val_acc:0.9205\n",
      "\t\ttrain_roc: 0.9789, val_roc: 0.9709, train_auprc: 0.9729, val_auprc: 0.9656\n",
      "scheduling\n",
      "Epoch: 101 (52.4110s), train_loss: 0.1735, val_loss: 0.2209, train_acc: 0.9393, val_acc:0.9192\n",
      "\t\ttrain_roc: 0.9790, val_roc: 0.9697, train_auprc: 0.9728, val_auprc: 0.9633\n",
      "scheduling\n",
      "Epoch: 102 (53.0397s), train_loss: 0.1721, val_loss: 0.2167, train_acc: 0.9397, val_acc:0.9203\n",
      "\t\ttrain_roc: 0.9792, val_roc: 0.9703, train_auprc: 0.9731, val_auprc: 0.9649\n",
      "scheduling\n",
      "Epoch: 103 (53.2260s), train_loss: 0.1724, val_loss: 0.2183, train_acc: 0.9392, val_acc:0.9204\n",
      "\t\ttrain_roc: 0.9794, val_roc: 0.9701, train_auprc: 0.9735, val_auprc: 0.9640\n",
      "scheduling\n",
      "Epoch: 104 (54.6688s), train_loss: 0.1721, val_loss: 0.2185, train_acc: 0.9391, val_acc:0.9191\n",
      "\t\ttrain_roc: 0.9794, val_roc: 0.9702, train_auprc: 0.9737, val_auprc: 0.9651\n",
      "scheduling\n",
      "Epoch: 105 (53.4410s), train_loss: 0.1720, val_loss: 0.2182, train_acc: 0.9391, val_acc:0.9202\n",
      "\t\ttrain_roc: 0.9795, val_roc: 0.9702, train_auprc: 0.9732, val_auprc: 0.9647\n",
      "scheduling\n",
      "Epoch: 106 (56.3148s), train_loss: 0.1722, val_loss: 0.2173, train_acc: 0.9391, val_acc:0.9204\n",
      "\t\ttrain_roc: 0.9793, val_roc: 0.9705, train_auprc: 0.9734, val_auprc: 0.9649\n",
      "scheduling\n",
      "Epoch: 107 (57.4605s), train_loss: 0.1716, val_loss: 0.2145, train_acc: 0.9397, val_acc:0.9219\n",
      "\t\ttrain_roc: 0.9795, val_roc: 0.9711, train_auprc: 0.9735, val_auprc: 0.9658\n",
      "scheduling\n",
      "Epoch: 108 (59.6880s), train_loss: 0.1710, val_loss: 0.2122, train_acc: 0.9396, val_acc:0.9233\n",
      "\t\ttrain_roc: 0.9797, val_roc: 0.9716, train_auprc: 0.9738, val_auprc: 0.9664\n",
      "scheduling\n",
      "Epoch: 109 (52.8420s), train_loss: 0.1713, val_loss: 0.2142, train_acc: 0.9397, val_acc:0.9212\n",
      "\t\ttrain_roc: 0.9794, val_roc: 0.9710, train_auprc: 0.9736, val_auprc: 0.9656\n",
      "scheduling\n",
      "Epoch: 110 (51.7014s), train_loss: 0.1695, val_loss: 0.2166, train_acc: 0.9405, val_acc:0.9212\n",
      "\t\ttrain_roc: 0.9799, val_roc: 0.9704, train_auprc: 0.9742, val_auprc: 0.9649\n",
      "scheduling\n",
      "Epoch: 111 (55.5486s), train_loss: 0.1692, val_loss: 0.2138, train_acc: 0.9407, val_acc:0.9214\n",
      "\t\ttrain_roc: 0.9801, val_roc: 0.9711, train_auprc: 0.9745, val_auprc: 0.9658\n",
      "scheduling\n",
      "Epoch: 112 (59.1217s), train_loss: 0.1691, val_loss: 0.2131, train_acc: 0.9406, val_acc:0.9224\n",
      "\t\ttrain_roc: 0.9799, val_roc: 0.9715, train_auprc: 0.9741, val_auprc: 0.9658\n",
      "scheduling\n",
      "Epoch: 113 (58.2297s), train_loss: 0.1692, val_loss: 0.2172, train_acc: 0.9405, val_acc:0.9208\n",
      "\t\ttrain_roc: 0.9799, val_roc: 0.9702, train_auprc: 0.9743, val_auprc: 0.9644\n",
      "scheduling\n",
      "Epoch: 114 (51.9224s), train_loss: 0.1694, val_loss: 0.2144, train_acc: 0.9402, val_acc:0.9222\n",
      "\t\ttrain_roc: 0.9799, val_roc: 0.9708, train_auprc: 0.9742, val_auprc: 0.9652\n",
      "scheduling\n",
      "Epoch: 115 (51.7920s), train_loss: 0.1692, val_loss: 0.2141, train_acc: 0.9407, val_acc:0.9228\n",
      "\t\ttrain_roc: 0.9798, val_roc: 0.9713, train_auprc: 0.9741, val_auprc: 0.9660\n",
      "scheduling\n",
      "Epoch: 116 (51.3280s), train_loss: 0.1690, val_loss: 0.2152, train_acc: 0.9406, val_acc:0.9218\n",
      "\t\ttrain_roc: 0.9800, val_roc: 0.9710, train_auprc: 0.9743, val_auprc: 0.9655\n",
      "scheduling\n",
      "Epoch: 117 (51.0223s), train_loss: 0.1705, val_loss: 0.2138, train_acc: 0.9403, val_acc:0.9222\n",
      "\t\ttrain_roc: 0.9796, val_roc: 0.9714, train_auprc: 0.9737, val_auprc: 0.9661\n",
      "scheduling\n",
      "Epoch: 118 (52.7407s), train_loss: 0.1686, val_loss: 0.2139, train_acc: 0.9411, val_acc:0.9216\n",
      "\t\ttrain_roc: 0.9801, val_roc: 0.9711, train_auprc: 0.9741, val_auprc: 0.9658\n",
      "scheduling\n",
      "Epoch: 119 (52.1970s), train_loss: 0.1684, val_loss: 0.2162, train_acc: 0.9416, val_acc:0.9207\n",
      "\t\ttrain_roc: 0.9801, val_roc: 0.9712, train_auprc: 0.9742, val_auprc: 0.9659\n",
      "scheduling\n",
      "Epoch: 120 (54.5979s), train_loss: 0.1664, val_loss: 0.2154, train_acc: 0.9418, val_acc:0.9218\n",
      "\t\ttrain_roc: 0.9806, val_roc: 0.9707, train_auprc: 0.9750, val_auprc: 0.9649\n"
     ]
    }
   ],
   "source": [
    "train(model, train_data_loader, val_data_loader, loss, optimizer, n_epochs, device, scheduler)"
   ]
  }
 ]
}